"""
Phase 1: Data Assembly — Pension Fund Home Bias Panel
======================================================
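Extends full_panel.csv with World Bank WDI financial-market indicators
and derived external-asset, financial-depth, and demographic interaction
variables to test whether aging drives pension fund growth and
cross-border diversification. (OECD pension fund statistics and IMF CPIS
data are not loaded by this script.)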

Key hypothesis: aging → pension AUM growth → home market insufficient
→ cross-border portfolio diversification. This explains why Z
affects portfolio flows but not FDI (Paper 2).

Output: pension_home_bias/data/processed/pension_panel.csv
Tables: summary_statistics.md
"""

import sys
from pathlib import Path
import time

import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

# Project layout: every location is derived from the pension_home_bias folder.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/pension_home_bias")
ROOT_DIR = PROJECT_DIR.parent
MULTILATERAL_DIR = ROOT_DIR.joinpath("multilateral")
PROCESSED_DIR = PROJECT_DIR.joinpath("data", "processed")
RAW_DIR = PROJECT_DIR.joinpath("data", "raw")
TABLES_DIR = PROJECT_DIR.joinpath("output", "tables")

# Create the data/output folders up front (including missing parents) so
# later writes cannot fail on a missing directory.
for d in (PROCESSED_DIR, RAW_DIR, TABLES_DIR):
    d.mkdir(parents=True, exist_ok=True)

# Put multilateral/src first on the import search path.
sys.path.insert(0, str(MULTILATERAL_DIR.joinpath("src")))

# ISO3 codes used to flag OECD membership (38 entries).
OECD_38 = (
    "AUS AUT BEL CAN CHL COL CRI CZE DNK EST "
    "FIN FRA DEU GRC HUN ISL IRL ISR ITA JPN "
    "KOR LVA LTU LUX MEX NLD NZL NOR POL PRT "
    "SVK SVN ESP SWE CHE TUR GBR USA"
).split()


def download_wdi_indicator(indicator, name, countries, max_retries=3, *,
                           chunk_size=15, start_year=1990, end_year=2024):
    """Download one WDI indicator from the World Bank API (v2, JSON).

    Countries are requested in chunks of ``chunk_size`` codes joined with
    ``;``.  Each chunk is attempted up to ``max_retries`` times; a chunk that
    still fails is reported on stdout and skipped (best effort), so the
    result may cover fewer countries than requested.

    Args:
        indicator: WDI indicator code placed in the request URL.
        name: Column name used for the indicator values in the result.
        countries: Iterable of country codes to request.
        max_retries: Maximum number of attempts per chunk.
        chunk_size: Number of country codes per request (must be >= 1).
        start_year: First year of the requested ``date`` range.
        end_year: Last year of the requested ``date`` range.

    Returns:
        DataFrame with columns ``iso3``, ``year`` and ``name``, one row per
        (iso3, year) with a non-null value.  Empty (but with those three
        columns) when nothing was retrieved.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    import json
    import urllib.request

    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    countries = list(countries)
    print(f"  Downloading WDI {indicator} ({name}) ...")
    all_rows = []

    for i in range(0, len(countries), chunk_size):
        chunk = countries[i:i + chunk_size]
        chunk_no = i // chunk_size
        url = (f"https://api.worldbank.org/v2/country/{';'.join(chunk)}/"
               f"indicator/{indicator}?format=json&per_page=10000"
               f"&date={start_year}:{end_year}")

        for attempt in range(max_retries):
            # Parse into a per-attempt buffer so that a failure half-way
            # through a response never leaves partial rows behind.
            chunk_rows = []
            try:
                req = urllib.request.Request(url)
                req.add_header('User-Agent', 'Mozilla/5.0')
                with urllib.request.urlopen(req, timeout=30) as resp:
                    data = json.loads(resp.read().decode())
                if len(data) >= 2 and data[1]:
                    for obs in data[1]:
                        if obs.get('value') is not None:
                            chunk_rows.append({
                                'iso3': obs['countryiso3code'],
                                'year': int(obs['date']),
                                name: float(obs['value']),
                            })
                else:
                    # The request succeeded but carried no observations
                    # (e.g. an API message instead of data) — say so rather
                    # than dropping the whole chunk silently.
                    print(f"    No data in response for chunk {chunk_no}: "
                          f"{str(data)[:200]}")
            except Exception as e:
                # Deliberately broad: this is a best-effort network boundary
                # and a failed chunk must not abort the whole download.
                if attempt == max_retries - 1:
                    print(f"    Failed chunk {chunk_no}: {e}")
                else:
                    # Back off only when another attempt will follow.
                    time.sleep(1.5)
            else:
                all_rows.extend(chunk_rows)
                break

        # Short pause between chunks to stay gentle on the API.
        time.sleep(0.5)

    if all_rows:
        df = pd.DataFrame(all_rows)
        df = df.drop_duplicates(subset=['iso3', 'year'])
        print(f"    {name}: {df['iso3'].nunique()} countries, {len(df)} obs")
        return df
    return pd.DataFrame(columns=['iso3', 'year', name])


def _load_wdi_financial(countries):
    """Return the WDI financial indicators as one (iso3, year)-keyed frame.

    Reads RAW_DIR/wdi_financial.csv when it exists; otherwise downloads each
    indicator with download_wdi_indicator, outer-merges the results and
    writes that cache file.  Returns a frame with only iso3/year columns
    when nothing could be downloaded.
    """
    raw_cache = RAW_DIR / "wdi_financial.csv"
    if raw_cache.exists():
        print("  Loading cached WDI financial data ...")
        return pd.read_csv(raw_cache)

    indicators = {
        'FS.AST.PRVT.GD.ZS': 'domestic_credit_private',  # Domestic credit to private sector (% GDP)
        'CM.MKT.LCAP.GD.ZS': 'stock_market_cap_gdp',     # Market cap of listed companies (% GDP)
        'CM.MKT.TRAD.GD.ZS': 'stock_traded_gdp',         # Stocks traded, total value (% GDP)
        # NOTE(review): FD.AST.PRVT.GD.ZS appears to be the World Bank code
        # for "Domestic credit to private sector by banks (% of GDP)", not a
        # deposits series — confirm the intended indicator.  The column name
        # is kept unchanged so existing consumers of the panel still work.
        'FD.AST.PRVT.GD.ZS': 'financial_system_deposits', # Financial system deposits (% GDP)
    }

    fin_dfs = []
    missing = []
    for indicator, name in indicators.items():
        idf = download_wdi_indicator(indicator, name, countries)
        if len(idf) > 0:
            fin_dfs.append(idf)
        else:
            missing.append(indicator)

    if not fin_dfs:
        print("  WARNING: no WDI financial data could be downloaded")
        return pd.DataFrame(columns=['iso3', 'year'])

    fin_df = fin_dfs[0]
    for idf in fin_dfs[1:]:
        fin_df = fin_df.merge(idf, on=['iso3', 'year'], how='outer')
    fin_df.to_csv(raw_cache, index=False)
    if missing:
        # The cache is reused on every later run, so make the gap visible.
        print(f"  WARNING: no data downloaded for {', '.join(missing)}; "
              f"delete {raw_cache.name} to retry on a later run")
    return fin_df


def _merge_new_indicator_columns(df, fin_df):
    """Left-merge onto df the fin_df columns that df does not already have."""
    if len(fin_df) == 0:
        return df

    existing = set(df.columns)
    new_cols = [c for c in fin_df.columns
                if c not in existing and c not in ('iso3', 'year')]
    if not new_cols:
        return df

    merge_df = fin_df[['iso3', 'year'] + new_cols].drop_duplicates(
        subset=['iso3', 'year'])
    df = df.merge(merge_df, on=['iso3', 'year'], how='left')
    for col in new_cols:
        print(f"  {col}: {df[col].notna().sum():,} obs")
    return df


def _add_derived_variables(df):
    """Add external-asset shares, financial depth and the diversification ratio.

    Each variable is only built when its input columns exist in df.  Columns
    are added to df in place; df is also returned.
    """
    # Gross external assets and liabilities as proxy for cross-border diversification
    if 'gross_assets_gdp' in df.columns:
        # Floor at 0.01 before taking logs so zero/negative values do not
        # produce -inf/NaN; genuinely missing inputs are reset to NaN below.
        df['log_gross_assets'] = np.log(df['gross_assets_gdp'].clip(lower=0.01))
        df.loc[df['gross_assets_gdp'].isna(), 'log_gross_assets'] = np.nan
        print(f"  log_gross_assets: {df['log_gross_assets'].notna().sum():,} obs")

    # Portfolio equity + debt assets as pension allocation proxy
    if 'port_eq_assets_gdp' in df.columns and 'debt_assets_gdp' in df.columns:
        df['portfolio_assets_gdp'] = df['port_eq_assets_gdp'] + df['debt_assets_gdp']
        print(f"  portfolio_assets_gdp: {df['portfolio_assets_gdp'].notna().sum():,} obs")

    # Debt share of external assets (pension tilt toward fixed income)
    if 'debt_assets_gdp' in df.columns and 'gross_assets_gdp' in df.columns:
        df['debt_share_assets'] = df['debt_assets_gdp'] / df['gross_assets_gdp'].clip(lower=0.01)
        df.loc[df['gross_assets_gdp'].isna(), 'debt_share_assets'] = np.nan
        print(f"  debt_share_assets: {df['debt_share_assets'].notna().sum():,} obs")

    # FDI share (should be low if pension-driven)
    if 'fdi_assets_gdp' in df.columns and 'gross_assets_gdp' in df.columns:
        df['fdi_share_assets'] = df['fdi_assets_gdp'] / df['gross_assets_gdp'].clip(lower=0.01)
        df.loc[df['gross_assets_gdp'].isna(), 'fdi_share_assets'] = np.nan
        print(f"  fdi_share_assets: {df['fdi_share_assets'].notna().sum():,} obs")

    # Financial depth (domestic market size — home bias denominator)
    if 'stock_market_cap_gdp' in df.columns and 'domestic_credit_private' in df.columns:
        # Sum whichever components are available, but stay NaN when BOTH are
        # missing.  Filling each with 0 would record a depth of 0 for
        # countries with no financial data at all and feed a bogus
        # denominator into the diversification ratio below.
        df['financial_depth'] = df[
            ['stock_market_cap_gdp', 'domestic_credit_private']
        ].sum(axis=1, min_count=1)
        print(f"  financial_depth: {df['financial_depth'].notna().sum():,} obs")

    # Diversification proxy: gross_assets / financial_depth — higher = more
    # diversified (a home-bias measure would be 1 minus this ratio).
    if 'gross_assets_gdp' in df.columns and 'financial_depth' in df.columns:
        # Denominator floored at 1 to keep the ratio bounded when depth is
        # close to zero.
        df['diversification'] = df['gross_assets_gdp'] / df['financial_depth'].clip(lower=1)
        df.loc[df['financial_depth'].isna() | df['gross_assets_gdp'].isna(), 'diversification'] = np.nan
        print(f"  diversification: {df['diversification'].notna().sum():,} obs")

    return df


def _add_interactions_and_lags(df):
    """Add Z interaction terms, 5-row lags, first differences and the old_dep lead.

    Returns df sorted by (iso3, year).
    """
    demo_vars = ['Z_1', 'Z_2', 'Z_3']

    for zv in demo_vars:
        if zv in df.columns:
            if 'pension_spending_gdp' in df.columns:
                df[f'{zv}_x_pension'] = df[zv] * df['pension_spending_gdp']
            if 'financial_depth' in df.columns:
                df[f'{zv}_x_findepth'] = df[zv] * df['financial_depth']

    # Lagged and differenced demographics
    # NOTE(review): shift/diff operate on row positions within each country,
    # so they equal year lags/leads only if every country has one row per
    # consecutive year — TODO confirm against full_panel.csv.
    df = df.sort_values(['iso3', 'year'])
    for zv in demo_vars:
        if zv in df.columns:
            df[f'{zv}_lag5'] = df.groupby('iso3')[zv].shift(5)
            df[f'd_{zv}'] = df.groupby('iso3')[zv].diff()

    if 'old_dep' in df.columns:
        # Lead of old_dep by 20 rows.  Rows after 2024 were already dropped
        # when the panel was loaded, so the last 20 rows of each country are
        # NaN here.
        df['oadr_plus20'] = df.groupby('iso3')['old_dep'].shift(-20)

    return df


def _add_regime_variables(df):
    """Add the OECD dummy and, when gdp_pc_ppp exists, per-year income terciles."""
    df['oecd'] = df['iso3'].isin(OECD_38).astype(int)

    if 'gdp_pc_ppp' in df.columns:
        def safe_qcut(x):
            # If qcut cannot build three labelled bins for a year (it raises
            # ValueError/IndexError), that whole year is set to NaN.
            try:
                return pd.qcut(x, 3, labels=['low', 'mid', 'high'], duplicates='drop')
            except (ValueError, IndexError):
                return pd.Series(np.nan, index=x.index)
        df['income_group'] = df.groupby('year')['gdp_pc_ppp'].transform(safe_qcut)

    return df


def _write_summary_table(df):
    """Write N/mean/SD/min/max of the key variables as a Markdown table."""
    key_vars = ['pension_spending_gdp', 'pension_coverage',
                'gross_assets_gdp', 'portfolio_assets_gdp',
                'debt_share_assets', 'fdi_share_assets',
                'financial_depth', 'diversification',
                'domestic_credit_private', 'stock_market_cap_gdp',
                'Z_1', 'Z_2', 'Z_3', 'old_dep', 'youth_dep',
                'ca_gdp', 'nfa_gdp_lag', 'kaopen', 'log_gdp_pc']
    key_vars = [v for v in key_vars if v in df.columns]

    md = ["# Summary Statistics — Pension Fund Home Bias Panel\n"]
    md.append("| Variable | N | Mean | SD | Min | Max |")
    md.append("|---|---|---|---|---|---|")
    for v in key_vars:
        s = df[v].dropna()
        if len(s) > 0:
            md.append(f"| {v} | {len(s):,} | {s.mean():.3f} | {s.std():.3f} "
                      f"| {s.min():.3f} | {s.max():.3f} |")
    out = TABLES_DIR / "summary_statistics.md"
    # Explicit UTF-8: the title contains an em dash and must not depend on
    # the platform's default encoding.
    out.write_text('\n'.join(md), encoding='utf-8')
    print(f"  Saved: {out}")


def main():
    """Assemble the pension home-bias panel and write it to disk.

    Steps: load full_panel.csv (years <= 2024), report existing pension
    columns, add WDI financial indicators (cached or downloaded), build
    derived variables, interactions, lags and regime flags, restrict to
    1990-2024, then write summary_statistics.md and pension_panel.csv.

    Returns:
        The assembled panel as a DataFrame (the same data that is saved).
    """
    print("=" * 70)
    print("PHASE 1: Data Assembly — Pension Fund Home Bias Panel")
    print("=" * 70)

    # ── [1] Load full_panel.csv ──
    print("\n[1] Loading full_panel.csv ...")
    full_path = MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv"
    df = pd.read_csv(full_path)
    df = df[df['year'] <= 2024].copy()
    countries = sorted(df['iso3'].unique())
    print(f"  Full panel: {df['iso3'].nunique()} countries, {len(df):,} obs")

    # ── [2] Pension data from full_panel ──
    print("\n[2] Checking existing pension data ...")
    for v in ['pension_spending_gdp', 'pension_coverage']:
        if v in df.columns:
            n = df[v].notna().sum()
            nc = df.loc[df[v].notna(), 'iso3'].nunique()
            print(f"  {v}: {n:,} obs, {nc} countries")

    # ── [3] Download additional WDI financial indicators ──
    print("\n[3] Downloading WDI financial indicators ...")
    fin_df = _load_wdi_financial(countries)
    df = _merge_new_indicator_columns(df, fin_df)

    # ── [4] Construct pension/institutional investor proxies ──
    print("\n[4] Constructing derived variables ...")
    df = _add_derived_variables(df)

    # ── [5] Interaction terms ──
    print("\n[5] Interaction terms ...")
    df = _add_interactions_and_lags(df)

    # ── [6] Regime variables ──
    print("\n[6] Regime variables ...")
    df = _add_regime_variables(df)

    # ── [7] Restrict and summarize ──
    # The lower year bound is applied only now so the lags and differences
    # above could still use pre-1990 rows.
    print("\n[7] Panel summary ...")
    df = df[(df['year'] >= 1990) & (df['year'] <= 2024)].copy()

    n_total = len(df)
    n_countries = df['iso3'].nunique()
    print(f"  Total panel: {n_countries} countries, {n_total:,} obs")

    # ── [8] Summary statistics ──
    print("\n[8] Building summary statistics ...")
    _write_summary_table(df)

    # ── [9] Save ──
    print("\n[9] Saving pension_panel.csv ...")
    df.to_csv(PROCESSED_DIR / "pension_panel.csv", index=False)
    print(f"  Saved: {PROCESSED_DIR / 'pension_panel.csv'}")
    print(f"  Shape: {df.shape[0]:,} obs x {df.shape[1]} columns")

    print("\n" + "=" * 70)
    print("Phase 1 complete.")
    print("=" * 70)

    return df


if __name__ == "__main__":
    # Run the pipeline only when executed as a script; the returned panel is
    # kept in the module-level name `df`.
    df = main()
